{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "## 07词的相似性和类比任务"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:06:41.256400Z",
     "iopub.status.busy": "2023-08-18T07:06:41.255749Z",
     "iopub.status.idle": "2023-08-18T07:06:43.288113Z",
     "shell.execute_reply": "2023-08-18T07:06:43.287240Z"
    },
    "origin_pos": 2,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 4
   },
   "source": [
    "## 加载预训练词向量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:06:43.292543Z",
     "iopub.status.busy": "2023-08-18T07:06:43.291837Z",
     "iopub.status.idle": "2023-08-18T07:06:43.297097Z",
     "shell.execute_reply": "2023-08-18T07:06:43.296299Z"
    },
    "origin_pos": 5,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "d2l.DATA_HUB['glove.6b.50d'] = (d2l.DATA_URL + 'glove.6B.50d.zip',\n",
    "                                '0b8703943ccdb6eb788e6f091b8946e82231bc4d')\n",
    "\n",
    "#@save\n",
    "d2l.DATA_HUB['glove.6b.100d'] = (d2l.DATA_URL + 'glove.6B.100d.zip',\n",
    "                                 'cd43bfb07e44e6f27cbcc7bc9ae3d80284fdaf5a')\n",
    "\n",
    "#@save\n",
    "d2l.DATA_HUB['glove.42b.300d'] = (d2l.DATA_URL + 'glove.42B.300d.zip',\n",
    "                                  'b5116e234e9eb9076672cfeabf5469f3eec904fa')\n",
    "\n",
    "#@save\n",
    "d2l.DATA_HUB['wiki.en'] = (d2l.DATA_URL + 'wiki.en.zip',\n",
    "                           'c1816da3821ae9f43899be655002f6c723e91b88')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:06:43.300883Z",
     "iopub.status.busy": "2023-08-18T07:06:43.300205Z",
     "iopub.status.idle": "2023-08-18T07:06:43.309328Z",
     "shell.execute_reply": "2023-08-18T07:06:43.308481Z"
    },
    "origin_pos": 7,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "class TokenEmbedding:\n",
    "    \"\"\"GloVe嵌入\"\"\"\n",
    "    def __init__(self, embedding_name):\n",
    "        self.idx_to_token, self.idx_to_vec = self._load_embedding(\n",
    "            embedding_name)\n",
    "        self.unknown_idx = 0\n",
    "        self.token_to_idx = {token: idx for idx, token in\n",
    "                             enumerate(self.idx_to_token)}\n",
    "\n",
    "    def _load_embedding(self, embedding_name):\n",
    "        idx_to_token, idx_to_vec = ['<unk>'], []\n",
    "        data_dir = d2l.download_extract(embedding_name)\n",
    "        # GloVe网站：https://nlp.stanford.edu/projects/glove/\n",
    "        # fastText网站：https://fasttext.cc/\n",
    "        with open(os.path.join(data_dir, 'vec.txt'), 'r') as f:\n",
    "            for line in f:\n",
    "                elems = line.rstrip().split(' ')\n",
    "                token, elems = elems[0], [float(elem) for elem in elems[1:]]\n",
    "                # 跳过标题信息，例如fastText中的首行\n",
    "                if len(elems) > 1:\n",
    "                    idx_to_token.append(token)\n",
    "                    idx_to_vec.append(elems)\n",
    "        idx_to_vec = [[0] * len(idx_to_vec[0])] + idx_to_vec\n",
    "        return idx_to_token, torch.tensor(idx_to_vec)\n",
    "\n",
    "    def __getitem__(self, tokens):\n",
    "        indices = [self.token_to_idx.get(token, self.unknown_idx)\n",
    "                   for token in tokens]\n",
    "        vecs = self.idx_to_vec[torch.tensor(indices)]\n",
    "        return vecs\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.idx_to_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:06:43.312986Z",
     "iopub.status.busy": "2023-08-18T07:06:43.312409Z",
     "iopub.status.idle": "2023-08-18T07:06:54.396038Z",
     "shell.execute_reply": "2023-08-18T07:06:54.395176Z"
    },
    "origin_pos": 9,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading ../data/glove.6B.50d.zip from http://d2l-data.s3-accelerate.amazonaws.com/glove.6B.50d.zip...\n"
     ]
    }
   ],
   "source": [
    "glove_6b50d = TokenEmbedding('glove.6b.50d')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:06:54.400162Z",
     "iopub.status.busy": "2023-08-18T07:06:54.399579Z",
     "iopub.status.idle": "2023-08-18T07:06:54.405466Z",
     "shell.execute_reply": "2023-08-18T07:06:54.404676Z"
    },
    "origin_pos": 11,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "400001"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(glove_6b50d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:06:54.408746Z",
     "iopub.status.busy": "2023-08-18T07:06:54.408294Z",
     "iopub.status.idle": "2023-08-18T07:06:54.413468Z",
     "shell.execute_reply": "2023-08-18T07:06:54.412687Z"
    },
    "origin_pos": 13,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3367, 'beautiful')"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "glove_6b50d.token_to_idx['beautiful'], glove_6b50d.idx_to_token[3367]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:06:54.416901Z",
     "iopub.status.busy": "2023-08-18T07:06:54.416268Z",
     "iopub.status.idle": "2023-08-18T07:06:54.421648Z",
     "shell.execute_reply": "2023-08-18T07:06:54.420466Z"
    },
    "origin_pos": 16,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def knn(W, x, k):\n",
    "    # 增加1e-9以获得数值稳定性\n",
    "    cos = torch.mv(W, x.reshape(-1,)) / (\n",
    "        torch.sqrt(torch.sum(W * W, axis=1) + 1e-9) *\n",
    "        torch.sqrt((x * x).sum()))\n",
    "    _, topk = torch.topk(cos, k=k)\n",
    "    return topk, [cos[int(i)] for i in topk]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:06:54.425376Z",
     "iopub.status.busy": "2023-08-18T07:06:54.424618Z",
     "iopub.status.idle": "2023-08-18T07:06:54.430025Z",
     "shell.execute_reply": "2023-08-18T07:06:54.428981Z"
    },
    "origin_pos": 19,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def get_similar_tokens(query_token, k, embed):\n",
    "    topk, cos = knn(embed.idx_to_vec, embed[[query_token]], k + 1)\n",
    "    for i, c in zip(topk[1:], cos[1:]):  # 排除输入词\n",
    "        print(f'{embed.idx_to_token[int(i)]}：cosine相似度={float(c):.3f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:06:54.433258Z",
     "iopub.status.busy": "2023-08-18T07:06:54.432943Z",
     "iopub.status.idle": "2023-08-18T07:06:54.481827Z",
     "shell.execute_reply": "2023-08-18T07:06:54.480628Z"
    },
    "origin_pos": 21,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "chips：cosine相似度=0.856\n",
      "intel：cosine相似度=0.749\n",
      "electronics：cosine相似度=0.749\n"
     ]
    }
   ],
   "source": [
    "get_similar_tokens('chip', 3, glove_6b50d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:06:54.486458Z",
     "iopub.status.busy": "2023-08-18T07:06:54.485962Z",
     "iopub.status.idle": "2023-08-18T07:06:54.508991Z",
     "shell.execute_reply": "2023-08-18T07:06:54.507938Z"
    },
    "origin_pos": 23,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "babies：cosine相似度=0.839\n",
      "boy：cosine相似度=0.800\n",
      "girl：cosine相似度=0.792\n"
     ]
    }
   ],
   "source": [
    "get_similar_tokens('baby', 3, glove_6b50d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:06:54.513356Z",
     "iopub.status.busy": "2023-08-18T07:06:54.512976Z",
     "iopub.status.idle": "2023-08-18T07:06:54.534489Z",
     "shell.execute_reply": "2023-08-18T07:06:54.533425Z"
    },
    "origin_pos": 24,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lovely：cosine相似度=0.921\n",
      "gorgeous：cosine相似度=0.893\n",
      "wonderful：cosine相似度=0.830\n"
     ]
    }
   ],
   "source": [
    "get_similar_tokens('beautiful', 3, glove_6b50d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:06:54.539108Z",
     "iopub.status.busy": "2023-08-18T07:06:54.538593Z",
     "iopub.status.idle": "2023-08-18T07:06:54.544150Z",
     "shell.execute_reply": "2023-08-18T07:06:54.543191Z"
    },
    "origin_pos": 26,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def get_analogy(token_a, token_b, token_c, embed):\n",
    "    vecs = embed[[token_a, token_b, token_c]]\n",
    "    x = vecs[1] - vecs[0] + vecs[2]\n",
    "    topk, cos = knn(embed.idx_to_vec, x, 1)\n",
    "    return embed.idx_to_token[int(topk[0])]  # 删除未知词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:06:54.548236Z",
     "iopub.status.busy": "2023-08-18T07:06:54.547963Z",
     "iopub.status.idle": "2023-08-18T07:06:54.569097Z",
     "shell.execute_reply": "2023-08-18T07:06:54.568018Z"
    },
    "origin_pos": 28,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'daughter'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_analogy('man', 'woman', 'son', glove_6b50d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:06:54.573551Z",
     "iopub.status.busy": "2023-08-18T07:06:54.573270Z",
     "iopub.status.idle": "2023-08-18T07:06:54.595104Z",
     "shell.execute_reply": "2023-08-18T07:06:54.594092Z"
    },
    "origin_pos": 30,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'japan'"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_analogy('beijing', 'china', 'tokyo', glove_6b50d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:06:54.599698Z",
     "iopub.status.busy": "2023-08-18T07:06:54.599313Z",
     "iopub.status.idle": "2023-08-18T07:06:54.621533Z",
     "shell.execute_reply": "2023-08-18T07:06:54.620486Z"
    },
    "origin_pos": 32,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'biggest'"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_analogy('bad', 'worst', 'big', glove_6b50d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:06:54.626086Z",
     "iopub.status.busy": "2023-08-18T07:06:54.625554Z",
     "iopub.status.idle": "2023-08-18T07:06:54.647570Z",
     "shell.execute_reply": "2023-08-18T07:06:54.646604Z"
    },
    "origin_pos": 34,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'went'"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_analogy('do', 'did', 'go', glove_6b50d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "required_libs": []
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
